#include "cal.h"
#include "cell.h"
#include "esmd_types.h"
#include "listed.hpp"
#include <cstdlib>
#include <qthread.h>
#include "memory_cpp.hpp"
#include "rigid_rec.hpp"
#include "utils.hpp"
#include <sys/cdefs.h>
#include <tuple>
#include <type_traits>
#include "swarch.h"
#include "cal.h"

// Capacity, in bytes, of the local buffer into which param_bundle packs every
// listed-force parameter table (instantiated as param_bundle<MAX_PARAM_SIZE>).
constexpr int MAX_PARAM_SIZE = 16384;
#ifdef __sw_host__
// Buffers shared with the slave kernels (re-declared extern in the
// __sw_slave__ section below). Both start empty and are allocated on first use
// by compute_listed_forces_sw.
vec<real> *frep = nullptr;
cal_lock_t *locks = nullptr;
// Slave-core entry point spawned by compute_listed_forces_sw; receives the
// (stat, grid) pair packed into a tuple.
void slave_compute_listed_forces_cpe(std::tuple<mdstat_t *, cellgrid_t *> *param);
/// Host driver for listed-force evaluation on the slave cores.
///
/// On the first call this allocates the shared `frep` buffer and the `locks`
/// array, each with grid->nall.vol() * CELL_CAP elements; the locks are
/// zero-filled. Later calls reuse those allocations. It then spawns
/// slave_compute_listed_forces_cpe with a pointer to a (stat, grid) tuple and
/// blocks until all slave cores have joined.
void compute_listed_forces_sw(mdstat_t *stat, cellgrid_t *grid){
  if (!frep)
    frep = esmd::allocate<vec<real>>(grid->nall.vol() * CELL_CAP, "listed/frep");
  if (!locks) {
    const auto nlocks = grid->nall.vol() * CELL_CAP;
    locks = esmd::allocate<cal_lock_t>(nlocks, "listed/locks");
    memset(locks, 0, nlocks * sizeof(cal_lock_t));
  }
  // Lives on this stack frame until qthread_join returns, so the slaves can
  // safely read it through the pointer.
  std::tuple<mdstat_t *, cellgrid_t *> spawn_arg(stat, grid);
  qthread_spawn(slave_compute_listed_forces_cpe, &spawn_arg);
  qthread_join();
}
// Slave-side export kernel (presumably the slave-compiled
// topology_grids_export_cpe; the Sunway toolchain prefixes slave symbols with
// `slave_`).
void slave_topology_grids_export_cpe (cellgrid_t *grid);
/// Run the topology export kernel on all slave cores for @p grid and wait for
/// completion.
void topology_grids_export_sw(cellgrid_t *grid)
{
  qthread_spawn(slave_topology_grids_export_cpe, grid);
  qthread_join();
}
// Slave-side import kernel (presumably the slave-compiled
// topology_grids_import_cpe).
void slave_topology_grids_import_cpe (cellgrid_t *grid);
/// Run the topology import kernel on all slave cores for @p grid and wait for
/// completion.
void topology_grids_import_sw(cellgrid_t *grid)
{
  qthread_spawn(slave_topology_grids_import_cpe, grid);
  qthread_join();
}
#endif
#ifdef __sw_slave__
#include <qthread_slave.h>
#include "memptr.hpp"
#include "listed_impl.hpp"
#include "reg_reduce.h"
// Buffers defined in the __sw_host__ section of this file and allocated lazily
// by compute_listed_forces_sw; the slave kernel hands both to cell_f_cache.
extern vec<real> *frep;
extern cal_lock_t *locks;

/// Evaluate every id-indexed entry of this cell.
///
/// The entry count and then the entries themselves are copied from `by_id`
/// into a local listed_force_cell_ents, whose calc() does the actual work on
/// the positions @p x and forces @p f using parameter table @p params.
/// @return the energy value reported by listed_force_cell_ents::calc.
template <typename ParamType, int Capacity>
real listed_force_cell<ParamType, Capacity>::calc(vec<real> *f, vec<real> *x, ParamType *params) {
  listed_force_cell_ents<ParamType, int, Capacity> staged;
  // The count must be known locally before the entries can be fetched.
  staged.cnt = fetch_ptr(&by_id.cnt);
  dma_getn(by_id.entries, staged.entries, staged.cnt);
  return staged.calc(f, x, params, body_seq);
}
// Strip one pointer level, optionally behind an lvalue reference, from T.
// Only `U*` and `U*&` are supported; any other T has no `type` member, so
// get_base_type<T> is ill-formed for it.
template <typename T> struct __get_base_type {};
template <typename T> struct __get_base_type<T *> { using type = T; };
// std::get<I> on a non-const tuple lvalue yields `U*&`; treat it like `U*`.
template <typename T> struct __get_base_type<T *&> : __get_base_type<T *> {};
// get_base_type<U*> and get_base_type<U*&> are both U.
template <typename T>
using get_base_type = typename __get_base_type<T>::type;
/// Local copy of the parameter tables of every listed-force type present in a
/// topology_grids.
///
/// All present tables are packed back to back into the fixed `params` byte
/// buffer; each table start is rounded up to its element type's alignment.
/// `pointers` holds, per topology type, where that table starts in the buffer.
///
/// @tparam MaxParamSize capacity of the packing buffer in bytes. Constructing
///         a bundle whose tables do not fit trips an assert instead of letting
///         the DMA overrun the buffer.
template<int MaxParamSize>
struct param_bundle{
  // Packing buffer for all parameter tables.
  char params[MaxParamSize];
  // Element I points at the table of topology type I inside `params`; only
  // meaningful when topo->present[I] was set at construction.
  std::tuple<harmonic_bond_param *,
             harmonic_angle_param *,
             urey_bradly_param *,
             cosine_torsion_param *,
             ryckaert_bellmans_param *,
             extended_ryckaert_bellmans_param *,
             harmonic_improper_param *> pointers;
  // First unused byte of `params`.
  char *cur;
  /// Round @p ptr up to the next multiple of alignof(T) and return it as T*.
  template<typename T>
  static __always_inline T *align(char *ptr){
    return (T*)(((long)ptr + alignof(T) - 1) & ~(alignof(T) - 1));
  }
  /// Fetch the tables of all types in [0, TOP_CALC_END) that @p topo marks present.
  param_bundle(topology_grids *topo) {
    cur = params;
    get_param(topo, utils::make_iseq<int, topology_grids::TOP_CALC_END>{});
  }
  // Expand get_param<I> over every index of the sequence, in order.
  template<int ...Is>
  void get_param(topology_grids *topo, utils::iseq<int, Is...> seq){
    utils::expand{
      (get_param<Is>(topo), 0)...
        };
  }
  // Append the parameter table of type I (if present) at the next aligned
  // position of the buffer and advance `cur` past it.
  template<int I>
  void get_param(topology_grids *topo){
    if (topo->present[I]) {
      auto &table = std::get<I>(pointers);
      auto nparam = topo->get<I>().nparam;
      table = align<get_base_type<decltype(std::get<I>(pointers))>>(cur);
      char *table_end = ((char*)table) + sizeof(*table) * nparam;
      // The whole table must fit in the remaining buffer space; otherwise the
      // DMA below would write past the end of `params`.
      assert(table_end <= params + MaxParamSize);
      dma_getn(topo->get<I>().param, table, nparam);
      cur = table_end;
    }
  }
  /// Table pointer of topology type I (reference into `pointers`).
  template<int I>
  __always_inline auto &get(){
    return std::get<I>(pointers);
  }
};
// lwpf3 profiling configuration for the slave kernels below. LWPF_UNIT names
// the unit used by lwpf_enter/lwpf_exit(LISTED); LWPF_KERNELS lists every
// kernel name passed to lwpf_start/lwpf_stop in this file.
#define LWPF_UNIT U(LISTED)
#define LWPF_KERNELS K (COMPUTE) K(CALC) K(GUEST_X) K(GUEST_F) K(EXPORT) K(IMPORT_ALL) K(IMPORT) K(SEARCH) K(JTAG) K(SPECIAL) K(GUEST) K(RELINK) K(RIGID) 
// Event selected for each of the four performance counters (identifiers are
// provided by the lwpf3/platform headers).
#define EVT_PC0 PC0_CNT_INST
#define EVT_PC1 PC1_CYC_DATA_REL_BLOCK
#define EVT_PC2 PC2_CNT_GLD
#define EVT_PC3 PC3_CYCLE
#include "lwpf3/lwpf.h"
/// Slave-core kernel: evaluate all present listed-force terms for the local
/// cells handed out by FOREACH_LOCAL_CELL_CPE_DYN and accumulate the energies.
///
/// @param param (stat, grid) pointers in main memory. The energies reduced
///              over all CPEs are added to the fetched *stat, which CPE 0
///              writes back.
///
/// Per cell:
///  1. Guest positions are gathered from neighbor cells into slots
///     [CELL_CAP, CELL_CAP + nguest) of the local x array, shifted by the
///     basis difference between the neighbor and this cell.
///  2. Every present topology type is evaluated on this cell.
///  3. Guest forces are scattered back through the force cache.
void compute_listed_forces_cpe(std::tuple<mdstat_t *, cellgrid_t *> *param){
  lwpf_enter(LISTED);
  lwpf_start(COMPUTE);

  mdstat_t stat = fetch_ptr(std::get<0>(*param));
  cellgrid_t grid = fetch_ptr(std::get<1>(*param));
  topology_grids topo = fetch_ptr(grid.topo);
  cell_x_cache<32, 8> xcache(grid.cells);
  cell_f_cache<32, 8> fcache(frep, locks);
  fcache.fill(&grid);
  param_bundle<MAX_PARAM_SIZE> params(&topo);
  int nneighbor = hdcell(grid.nn, grid.nn, grid.nn, grid.nn) + 1;

  // This CPE's partial energies. They start from zero rather than from the
  // values fetched into `stat`: every CPE fetches the same `stat`, so seeding
  // with it would add the incoming energies once per CPE in the reduction.
  decltype(stat.ebond) ebond_local = 0;
  decltype(stat.eangle) eangle_local = 0;
  decltype(stat.etori) etori_local = 0;
  decltype(stat.eimpr) eimpr_local = 0;

  // Reset the shared dynamic cell counter; the memory barrier orders the store
  // before the CPE barrier.
  if (_MYID == 0){
    *lbal_cnt = 0;
  }
  asm volatile("memb\n\t" ::: "memory");
  qthread_syn();
  FOREACH_LOCAL_CELL_CPE_DYN(&grid, cx, cy, cz, icell){
    int ioff = get_offset_xyz<true>(grid, cx, cy, cz);
    cellmeta_t imeta = fetch_ptr((cellmeta_t*)&icell->basis);
    auto x = array_in(icell->x, icell->natom);
    auto f = array_inout(icell->f, icell->natom);
    auto first_guest_cell = array_in(icell->first_guest_cell, nneighbor + 1);
    auto guest_id = array_in(icell->guest_id, imeta.nguest);

    lwpf_start(GUEST_X);
    FOREACH_NEIGHBOR(&grid, cx, cy, cz, dx, dy, dz, jcell){
      int did = hdcell(grid.nn, dx, dy, dz);
      if (first_guest_cell[did] == first_guest_cell[did + 1]) continue;
      int joff = get_offset_xyz<true>(grid, cx + dx, cy + dy, cz + dz);
      cellmeta_t jmeta = fetch_ptr((cellmeta_t*)&jcell->basis);
      vec<real> dcell = jmeta.basis - imeta.basis;
      for (int i = first_guest_cell[did]; i < first_guest_cell[did + 1]; i ++){
        x[CELL_CAP + i] = xcache(joff, guest_id[i]) + dcell;
        f[CELL_CAP + i] = 0;
      }
    }
    lwpf_stop(GUEST_X);

    lwpf_start(CALC);
    if (topo.present[TOP_HARMONIC_BOND]) ebond_local += topo.get<TOP_HARMONIC_BOND>().cells[ioff].calc(f, x, params.get<TOP_HARMONIC_BOND>());
    if (topo.present[TOP_HARMONIC_ANGLE]) eangle_local += topo.get<TOP_HARMONIC_ANGLE>().cells[ioff].calc(f, x, params.get<TOP_HARMONIC_ANGLE>());
    if (topo.present[TOP_UREY_BRADLY]) eangle_local += topo.get<TOP_UREY_BRADLY>().cells[ioff].calc(f, x, params.get<TOP_UREY_BRADLY>());
    if (topo.present[TOP_COSINE_TORSION]) etori_local += topo.get<TOP_COSINE_TORSION>().cells[ioff].calc(f, x, params.get<TOP_COSINE_TORSION>());
    // Ryckaert-Bellemans (plain and extended) are proper-dihedral potentials,
    // so their energy is reported with the other torsions, not with bonds.
    if (topo.present[TOP_RYCKAERT_BELLMANS]) etori_local += topo.get<TOP_RYCKAERT_BELLMANS>().cells[ioff].calc(f, x, params.get<TOP_RYCKAERT_BELLMANS>());
    if (topo.present[TOP_EXTENDED_RYCKAERT_BELLMANS]) etori_local += topo.get<TOP_EXTENDED_RYCKAERT_BELLMANS>().cells[ioff].calc(f, x, params.get<TOP_EXTENDED_RYCKAERT_BELLMANS>());
    if (topo.present[TOP_HARMONIC_IMPROPER]) eimpr_local += topo.get<TOP_HARMONIC_IMPROPER>().cells[ioff].calc(f, x, params.get<TOP_HARMONIC_IMPROPER>());
    lwpf_stop(CALC);

    // Scatter guest forces back to their home cells. Only the neighbor's
    // offset is needed here, so no per-neighbor metadata is fetched.
    lwpf_start(GUEST_F);
    FOREACH_NEIGHBOR(&grid, cx, cy, cz, dx, dy, dz, jcell){
      int did = hdcell(grid.nn, dx, dy, dz);
      if (first_guest_cell[did] == first_guest_cell[did + 1]) continue;
      int joff = get_offset_xyz<true>(grid, cx + dx, cy + dy, cz + dz);
      for (int i = first_guest_cell[did]; i < first_guest_cell[did + 1]; i ++){
        fcache(joff, guest_id[i]) += f[CELL_CAP + i];
      }
    }
    lwpf_stop(GUEST_F);
  }

  fcache.flush();
  fcache.sum(&grid);
  // Sum the partial energies over all CPEs; CPE 0 adds the totals to the
  // incoming values and writes the statistics back.
  double flat_stat[4] = {ebond_local, eangle_local, etori_local, eimpr_local};
  reg_reduce_inplace_doublev4(flat_stat, 1);
  if (_MYID == 0){
    stat.ebond += flat_stat[0];
    stat.eangle += flat_stat[1];
    stat.etori += flat_stat[2];
    stat.eimpr += flat_stat[3];
    dma_putn(std::get<0>(*param), &stat, 1);
  }
  lwpf_stop(COMPUTE);
  lwpf_exit(LISTED);
}

#include "hmini.hpp"
/// Partition this (global-memory) entry list by whether each entry's owner
/// atom is in @p tag_map, working on a local copy.
///
/// Entries whose `id[ParamType::owner]` tag is contained in @p tag_map are
/// compacted, in their original order, to the front and become the new `cnt`.
/// All other entries are written, in reverse encounter order, to the tail slots
/// [Capacity - nexport, Capacity), and `nexport` records how many there are.
/// The first effective_size() chars of the local copy and the tail slots are
/// then DMA'd back over `this`.
///
/// NOTE(review): the front and tail regions share one local array. Once the
/// number of exported entries exceeds the free space (Capacity - cnt), a tail
/// write `entries[--ptr_move]` can land on an index > i that has not been read
/// yet and overwrite it. This is only safe if cnt + nexport <= Capacity always
/// holds; confirm that invariant.
template <typename ParamType, typename IdType, int Capacity>
template<typename TagMapType> 
void listed_force_cell_ents<ParamType, IdType, Capacity>::export_entries(const TagMapType &tag_map) {
  listed_force_cell_ents<ParamType, IdType, Capacity> ldm_ents;
  list_get(this, &ldm_ents);
  // ptr_keep: next front slot for kept entries; ptr_move: start of the tail
  // region filled backwards with exported entries.
  int ptr_keep = 0, ptr_move = Capacity;
  for (int i = 0; i < ldm_ents.cnt; i++) {
    if (tag_map.contains(ldm_ents.entries[i].id[ParamType::owner])) {
      ldm_ents.entries[ptr_keep++] = ldm_ents.entries[i];
    } else {
      ldm_ents.entries[--ptr_move] = ldm_ents.entries[i];
    }
  }
  ldm_ents.cnt = ptr_keep;
  ldm_ents.nexport = Capacity - ptr_move;
  // Write back the header/front part, then the exported tail separately.
  dma_putn(this, (char*)&ldm_ents, ldm_ents.effective_size());
  dma_putn(this->entries + ptr_move, ldm_ents.entries + ptr_move, ldm_ents.nexport);
}

/// Slave-core kernel: for every local cell (handed out dynamically), collect
/// the tags of the atoms the cell currently holds and pass them to
/// topology_grids::do_export for that cell's offset (presumably moving entries
/// whose owner is no longer in the cell to the export region).
///
/// Profiling now enters/exits the LISTED unit like the compute and import
/// kernels do. The reset of the dynamic cell counter is fenced with `memb`
/// before the barrier, as in compute_listed_forces_cpe.
void topology_grids_export_cpe(cellgrid_t *g_grid) {
  lwpf_enter(LISTED);
  lwpf_start(EXPORT);
  cellgrid_t grid = fetch_ptr(g_grid);
  topology_grids topo = fetch_ptr(grid.topo);
  if (_MYID == 0)
    *lbal_cnt = 0;
  asm volatile("memb\n\t" ::: "memory");
  qthread_syn();
  FOREACH_LOCAL_CELL_CPE_DYN(&grid, cx, cy, cz, icell){
    mini_hset<tagint, CELL_CAP * 8 / 3> tagset;
    auto imeta = fetch_ptr((cellmeta_t*)&icell->basis);
    auto tags = array_in(icell->tag, imeta.natom);
    for (int i = 0; i < imeta.natom; i ++){
      tagset.insert(tags[i]);
    }

    topo.do_export(get_offset_xyz<true>(grid, cx, cy, cz), tagset);
  }
  lwpf_stop(EXPORT);
  lwpf_exit(LISTED);
}

/// Import, for topology type I, the entries that near-neighbor cells exported
/// into the tag-indexed list of cell @p self, then (for guest-bearing types)
/// collect guest atoms.
///
/// Does nothing when type I is not present. Otherwise the cell's `by_tag` list
/// is copied locally and, for every near neighbor other than the cell itself
/// that has a non-zero `nexport`, that neighbor's exported tail
/// [cap - nexport, cap) is DMA'd into the same tail slots of the local copy
/// and handed to import_entries(tagmap, ...). The cell's own `nexport` is
/// restored and the list is written back with list_put. For I < TOP_GUEST_END,
/// find_guests(guestmap, tagmap) is then run on the updated list.
///
/// NOTE(review): neighbor data overwrites the tail of the local copy, which
/// held this cell's own exported entries. Whether list_put writes that tail
/// back (and could clobber data other CPEs are still importing) depends on
/// list_put/effective_size; verify.
template <int I, typename TagMapType, typename GuestMapType>
__attribute__((noinline)) void topology_grids::do_import_neighbors(int self, TagMapType &tagmap, GuestMapType &guestmap, cellgrid_t *grid, int cx, int cy, int cz){
  if (!present[I]) return;
  //puts(__PRETTY_FUNCTION__);
  auto cells = get<I>().cells;
  auto &itags = cells[self].by_tag;
  typedef std::remove_reference_t<decltype(itags)> cell_ent_type;
  cell_ent_type litopo;
  list_get(&itags, &litopo);
  // Saved because litopo.nexport is reused below as a scratch field for each
  // neighbor's export count.
  int nexport0 = litopo.nexport;
  FOREACH_NEARNEIGHBOR(grid, cx, cy, cz, dx, dy, dz, jcell){
    if (dx == 0 && dy == 0 && dz == 0) continue;
    int joff = get_offset_xyz<true>(grid, cx + dx, cy + dy, cz + dz);
    dma_getn(&cells[joff].by_tag.nexport, &litopo.nexport, 1);
    if (litopo.nexport == 0) continue;
    // Exported entries live at the tail of the neighbor's entry array.
    int start = cell_ent_type::cap - litopo.nexport;
    dma_getn(cells[joff].by_tag.entries +  start, litopo.entries + start, litopo.nexport);
    litopo.import_entries(tagmap, litopo);
  }
  litopo.nexport = nexport0;
  //printf("%d %d %d %d\n", I, TOP_END, TOP_GUEST_END, litopo.cnt);
  //printf("%d\n", litopo.effective_size());
  list_put(&itags, &litopo);
  //puts("finding guests");
  // Only the first TOP_GUEST_END topology types contribute guest atoms.
  if (I < TOP_GUEST_END) {
    litopo.find_guests(guestmap, tagmap);
  }
}

/// Rebuild this cell's id-indexed entry list from its tag-indexed list.
///
/// Every `by_tag` entry is translated through @p tag_map (tag -> local index)
/// at each atom position in Is..., keeps its parameter id `pid`, and is
/// streamed into `by_id.entries` through a buffered list_writer. The entry
/// count is then stored into `by_id.cnt`.
template <typename ParamType, int Capacity>
template <typename TagMapType, int... Is>
__attribute__((noinline)) void listed_force_cell<ParamType, Capacity>::relink(const TagMapType &tag_map, utils::iseq<int, Is...> seq) {
  using id_entry_t = std::remove_reference_t<decltype(by_id.entries[0])>;
  using tag_list_t = std::remove_reference_t<decltype(by_tag)>;
  list_writer<id_entry_t, 64> id_out(&by_id.entries[0]);
  tag_list_t src;
  list_get(&by_tag, &src);
  for (int k = 0; k < src.cnt; k++) {
    auto &tagged = src.entries[k];
    id_entry_t linked;
    std::tie(linked.id[Is]..., linked.pid) = std::tie(tag_map[tagged.id[Is]]..., tagged.pid);
    id_out.append(linked);
  }
  dma_putn(&by_id.cnt, &src.cnt, 1);
}
/// Build the neighbor metadata of local cell @p icell after topology import.
///
/// For every neighbor cell (hdcell index `did`) this records the start
/// offsets first_guest_cell / first_excl_cell / first_scal_cell[did]. The slot
/// one past the last neighbor receives the end offsets. Only the nearest shell
/// (|d| <= 1 on every axis) is searched:
///  - each special-pair entry of this cell whose partner tag (id[1]) is in
///    the neighbor becomes an exclusion (EXCL) or scaled (SCAL) index pair
///    {tagmap[id[0]], index of id[1] in the neighbor}; SCAL pairs within the
///    cell itself are kept once (i <= j);
///  - atoms of other cells whose tag is in @p guest_set are appended as
///    guests: type, rmass, source index and tag go to guest slot CELL_CAP + k.
/// @return number of guests appended (also stored into icell->nguest).
template<typename TagMapType, typename GuestMapType>
__attribute__((noinline)) int neighbor_search(cellgrid_t *grid, topology_grids &topo, celldata_t *icell, tagint *tags, TagMapType &tagmap, GuestMapType &guest_set, int cx, int cy, int cz){
  auto &special_cell = topo.get<special_pair>().cells[get_offset_xyz<true>(grid, cx, cy, cz)];
  std::remove_reference_t<decltype(special_cell.by_tag)> special_tags;
  list_get(&special_cell.by_tag, &special_tags);
  int guest_offset = CELL_CAP;
  int nexcl = 0, nscal = 0;
  int nneighbor = hdcell(grid->nn, grid->nn, grid->nn, grid->nn) + 1;
  auto first_guest_cell = array_out(icell->first_guest_cell, nneighbor + 1);
  auto first_excl_cell = array_out(icell->first_excl_cell, nneighbor + 1);
  auto first_scal_cell = array_out(icell->first_scal_cell, nneighbor + 1);
  list_writer<int[2], 128> excl_list(icell->excl_id);
  list_writer<int[2], 128> scal_list(icell->scal_id);
  // Guest staging buffers; each is declared with capacity CELL_CAP.
  memptr_t<int, CELL_CAP, MP_OUT> guest_t(icell->t + CELL_CAP, 0);
  memptr_t<int, CELL_CAP, MP_OUT> guest_id(icell->guest_id, 0);
  memptr_t<real, CELL_CAP, MP_OUT> guest_rmass(icell->rmass + CELL_CAP, 0);
  FOREACH_NEIGHBOR(grid, cx, cy, cz, dx, dy, dz, jcell) {
    int did = hdcell(grid->nn, dx, dy, dz);
    first_guest_cell[did] = guest_offset - CELL_CAP;
    first_excl_cell[did] = nexcl;
    first_scal_cell[did] = nscal;
    // Restrict the search to the nearest shell in both directions. The
    // previous test (dx > 1 || ...) only rejected positive offsets, letting
    // cells at offsets -2 .. -nn through.
    if (std::abs(dx) > 1 || std::abs(dy) > 1 || std::abs(dz) > 1) continue;
    auto jmeta = fetch_ptr((cellmeta_t*)&jcell->basis);
    lwpf_start(JTAG);
    mini_htab<tagint, int, CELL_CAP * 4 / 3> jtab;
    memptr_t<tagint, CELL_CAP, MP_IN> jtag(jcell->tag, jmeta.natom);
    bool has_guest = false;
    for (int j = 0; j < jmeta.natom; j ++) {
      jtab[jtag[j]] = j;
      if (guest_set.contains(jtag[j])) has_guest = true;
    }
    lwpf_stop(JTAG);
    lwpf_start(SPECIAL);
    for (auto &special_ent : special_tags){
      if (jtab.contains(special_ent.id[1])) {
        int i = tagmap[special_ent.id[0]];
        int j = jtab[special_ent.id[1]];
        if (special_ent.pid == special_pair::EXCL){
          excl_list.append((int[2]){i, j});
          nexcl ++;
        } else if (special_ent.pid == special_pair::SCAL) {
          // Within the cell itself each scaled pair is recorded only once.
          if (icell == jcell && i > j) continue;
          scal_list.append((int[2]){i, j});
          nscal ++;
        }
      }
    }
    lwpf_stop(SPECIAL);

    if (icell == jcell || !has_guest) continue;
    lwpf_start(GUEST);
    memptr_t<int, CELL_CAP, MP_IN> jt(jcell->t, jmeta.natom);
    memptr_t<real, CELL_CAP, MP_IN> jrmass(jcell->rmass, jmeta.natom);
    for (int j = 0; j < jmeta.natom; j ++){
      if (guest_set.contains(jtag[j])) {
        // Guard the CELL_CAP-sized guest staging buffers against overflow.
        assert(guest_offset - CELL_CAP < CELL_CAP);
        guest_t[guest_offset - CELL_CAP] = jt[j];
        guest_rmass[guest_offset - CELL_CAP] = jrmass[j];
        guest_id[guest_offset - CELL_CAP] = j;
        tags[guest_offset] = jtag[j];
        guest_offset ++;
      }
    }
    lwpf_stop(GUEST);
  }
  int nguest = guest_offset - CELL_CAP;
  guest_t.count = nguest;
  guest_id.count = nguest;
  guest_rmass.count = nguest;
  dma_putn(&icell->nguest, &nguest, 1);
  dma_putn(icell->tag + CELL_CAP, tags + CELL_CAP, nguest);
  // End sentinels of the three prefix arrays.
  first_guest_cell[hdcell(grid->nn, grid->nn, grid->nn, grid->nn+1)] = nguest;
  first_excl_cell[hdcell(grid->nn, grid->nn, grid->nn, grid->nn+1)] = nexcl;
  first_scal_cell[hdcell(grid->nn, grid->nn, grid->nn, grid->nn+1)] = nscal;
  return nguest;
}
/// Slave-core kernel: rebuild the listed-force topology of every local cell
/// (handed out dynamically) after atoms have migrated.
///
/// Per cell:
///  1. Map the cell's atom tags to local indices.
///  2. Import entries exported by near neighbors (do_import_neighbors over all
///     TOP_END types), gathering the tags of required guests.
///  3. Locate specials and guests in the neighbor cells (neighbor_search).
///  4. Register the guests' local indices and relink every list to indices.
///  5. If rigid constraints are present, fill icell->rigid from the rigid
///     topology and its parameter table.
void topology_grids_import_cpe(cellgrid_t *g_grid){
  lwpf_enter(LISTED);
  lwpf_start(IMPORT_ALL);
  cellgrid_t lgrid = fetch_ptr(g_grid);
  cellgrid_t *grid = &lgrid;
  topology_grids topo = fetch_ptr(grid->topo);
  // Reset the shared dynamic cell counter and order the store before the
  // barrier, as compute_listed_forces_cpe does.
  if (_MYID == 0)
    *lbal_cnt = 0;
  asm volatile("memb\n\t" ::: "memory");
  qthread_syn();
  FOREACH_LOCAL_CELL_CPE_DYN(grid, cx, cy, cz, icell) {
    cellmeta_t imeta = fetch_ptr((cellmeta_t*)&icell->basis);
    auto tag = array_in(icell->tag, imeta.natom);
    mini_htab<tagint, int, CELL_CAP * 8 / 3> tagmap;
    for (int i = 0; i < imeta.natom; i ++){
      tagmap[tag[i]] = i;
    }
    int ioff = get_offset_xyz<true>(grid, cx, cy, cz);
    mini_hset<tagint, MAX_CELL_GUEST * 4 / 3> guest_set;

    lwpf_start(IMPORT);
    topo.do_import_neighbors(ioff, tagmap, guest_set, grid, cx, cy, cz, utils::make_iseq<int, TOP_END>{});
    lwpf_stop(IMPORT);

    lwpf_start(SEARCH);
    int nguest = neighbor_search(grid, topo, icell, &tag[0], tagmap, guest_set, cx, cy, cz);
    lwpf_stop(SEARCH);
    assert(nguest == guest_set.size());

    lwpf_start(RELINK);
    // Guests occupy local slots CELL_CAP .. CELL_CAP + nguest - 1.
    for (int i = 0; i < nguest; i ++){
      tagmap[tag[CELL_CAP + i]] = CELL_CAP + i;
    }
    topo.relink(ioff, tagmap);
    lwpf_stop(RELINK);

    lwpf_start(RIGID);
    if (topo.present[TOP_RIGID]){
      auto &rigid_topo = topo.get<rigid_param>().cells[ioff];
      decltype(rigid_topo.by_id) rigids_list;
      list_get(&rigid_topo.by_id, &rigids_list);
      auto rigids_cell = array_out(icell->rigid, imeta.natom);
      rigid_param* g_params = topo.get<rigid_param>().param;
      int nparam = topo.get<rigid_param>().nparam;
      assert(nparam < 256);
      // Fetch exactly the nparam valid parameters instead of relying on the
      // count-less constructor.
      memptr_t<rigid_param, 256, MP_IN> params(g_params, nparam);
      for (int i = 0; i < imeta.natom; i ++){
        rigids_cell[i].type = 0;
      }
      for (int i = 0; i < rigids_list.cnt; i ++){
        auto &ent = rigids_list.entries[i];
        int owner = ent.id[0];
        auto &param = params[ent.pid];

        rigids_cell[owner].id[0] = ent.id[1];
        rigids_cell[owner].id[1] = ent.id[2];
        rigids_cell[owner].id[2] = ent.id[3];
        rigids_cell[owner].type = param.type - 1;
        rigids_cell[owner].r0[0] = param.r0[0];
        rigids_cell[owner].r0[1] = param.r0[1];
        rigids_cell[owner].r0[2] = param.r0[2];
      }
    }
    lwpf_stop(RIGID);
  }

  lwpf_stop(IMPORT_ALL);
  lwpf_exit(LISTED);
}
#endif
